import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
from itertools import product
from sqlalchemy.dialects.mssql.information_schema import columns


from PolicyLearning.Utility import LoadData2, GetRiskGain,  CD_234_shared2, GetRiskGain2, TPSort

# Make float64 the default floating-point dtype for newly created tensors.
# torch.set_default_tensor_type() is deprecated (PyTorch 2.1+); the documented
# replacement for the dtype part is torch.set_default_dtype().
torch.set_default_dtype(torch.float64)

# Convention used throughout this file: every score is coded 0-4.

# Policy lookup tables. One is subtracted right after loading
# (presumably to map scores stored as 1-5 onto the 0-4 coding -- confirm).
table2_array = np.load('Database/table2_array.npy') - 1
table3_array = np.load('Database/table3_array.npy') - 1

# The same tables reshaped so that each input score gets its own axis.
table2_cur = np.array(table2_array).reshape((5, 5))
table3_cur = np.array(table3_array).reshape((5, 5, 5))

# One-hot versions: row i carries a single 1 in the column given by the
# table entry at flat position i.
table2_array_bin = np.zeros((25, 5), dtype=int)
table3_array_bin = np.zeros((125, 5), dtype=int)
for i, out_score in enumerate(table2_array):
    table2_array_bin[i, out_score] = 1
for i, out_score in enumerate(table3_array):
    table3_array_bin[i, out_score] = 1


# Data set name and the three outcome columns whose sample files are read below.
DataName = 'Data1969_09'
OutcomeCol1, OutcomeCol2, OutcomeCol3 = 'sec_c', 'econ_c', 'soccap_c'

# Load the data; column 'my4' is shifted by -1 to follow the 0-4 coding.
dt = LoadData2(DataName)
Safety_Score = dt.loc[:, 'my4'].values - 1
temp1 = np.arange(5, dtype=int)
# 25*a + 5*b + c for every (a, b, c) in {0..4}^3, in lexicographic order.
temp4 = np.array([25 * a + 5 * b + c for a, b, c in product(range(5), repeat=3)])

# Settings that are encoded in the sample file names.
R = 4000
Matern_length_scale = 1
Matern_nu = 1.5
sig_latent = 2
sig_noise = 1

# Shared file-name suffix: R, then each hyper-parameter times 100 truncated to int.
_setting_tags = [str(R)] + [
    str(int(100 * value))
    for value in (Matern_length_scale, Matern_nu, sig_latent, sig_noise)
]
myfilename1, myfilename2, myfilename3 = (
    '_'.join([DataName, outcome, 'TotalResult'] + _setting_tags) + '.pt'
    for outcome in (OutcomeCol1, OutcomeCol2, OutcomeCol3)
)

# Stored tensors are described as 5*R*N. They are loaded onto the CPU and the
# first 2000 entries along the second axis are dropped (presumably burn-in --
# confirm).
TotalResult1, TotalResult2, TotalResult3 = (
    torch.load("Database/MCMCSample/" + sample_file, map_location=torch.device('cpu'))[:, 2000:, :]
    for sample_file in (myfilename1, myfilename2, myfilename3)
)


# Outcome analysed by the rest of the script; the risk/gain tables are
# derived from its samples.
TotalResult = TotalResult1
Risk_table, Gain_table = GetRiskGain2(TotalResult)


def _zero_based_frame(source_cols):
    """Return the listed ``dt`` columns shifted by -1 and relabelled V1, V2, ..."""
    shifted = dt.loc[:, source_cols].values - 1
    return pd.DataFrame(shifted, columns=['V1', 'V2', 'V3'][:len(source_cols)])


# One frame per aggregation step; each holds the input scores of that step.
X4_input = _zero_based_frame(['3a', '3b', '3c'])
X3a_input = _zero_based_frame(['2a', '2b', 'mod1c_num'])
X3b_input = _zero_based_frame(['2d', '2c', 'mod1m_num'])
X3c_input = _zero_based_frame(['2e', '2f'])
X2a_input = _zero_based_frame(['mod1a_num', 'mod1b_num'])
X2b_input = _zero_based_frame(['mod1e_num', 'mod1d_num', '1f1g'])
X2c_input = _zero_based_frame(['mod1h_num', 'mod1i_num'])
X2d_input = _zero_based_frame(['mod1j_num', '1l1k', 'mod1g_num'])
X2e_input = _zero_based_frame(['mod1o_num', 'mod1n_num', 'mod1p_num'])
X2f_input = _zero_based_frame(['mod1r_num', 'mod1q_num', 'mod1s_num'])

# Sanity check: looking the inputs of each aggregation step up in the policy
# tables must reproduce the aggregated score stored in dt (after the -1 shift).
def _count_mismatches(table, inputs, target_col, weights):
    """Count rows where ``table`` at the flattened input index differs from ``dt[target_col] - 1``."""
    flat_index = sum(inputs.iloc[:, pos] * weight for pos, weight in enumerate(weights))
    return (table[flat_index] != dt[target_col] - 1).sum()


# (input frame, column of dt that should be reproduced)
_three_input_checks = [
    (X2b_input, '2b'),
    (X2d_input, '2d'),
    (X2e_input, '2e'),
    (X2f_input, '2f'),
    (X3a_input, '3a'),
    (X3b_input, '3b'),
    (X4_input, 'my4'),
]
_two_input_checks = [
    (X2a_input, '2a'),
    (X2c_input, '2c'),
    (X3c_input, '3c'),
]
_mismatch_total = sum(
    [_count_mismatches(table3_array, frame, col, (25, 5, 1)) for frame, col in _three_input_checks]
    + [_count_mismatches(table2_array, frame, col, (5, 1)) for frame, col in _two_input_checks]
)
# A failure is only reported on stdout; the script keeps running either way.
if _mismatch_total != 0:
    print("Didn't pass the check!!!!!!!!!!!!!!!!!!!!!!!!")
else:
    print('Passed sanity check')

def onerun(Epsilon):
    """Run one TPSort optimisation and return its risk/gain and PDP output.

    Reads the module-level input frames, ``Risk_table``, ``Gain_table`` and
    ``dt``.

    Args:
        Epsilon: forwarded to ``TPSort`` as its fifth positional argument
            (its exact meaning is defined by ``TPSort``, not visible here).

    Returns:
        A tuple ``(TPs.ComputeRiskGain(), TPs.PDP(...))`` where the PDP is
        evaluated on the first 19 columns of ``dt`` shifted by -1.
    """
    step_inputs = [
        X4_input, X3a_input, X3b_input, X3c_input, X2a_input,
        X2b_input, X2c_input, X2d_input, X2e_input, X2f_input,
    ]
    TPs = TPSort(step_inputs, Risk_table, Gain_table, 4000, Epsilon, 0.0)
    TPs.ShortBurstMCMCOptim(Itenum=400, Burst_size=18, regularization=0.0)

    # Risk/gain is computed before the PDP, matching the order of the tuple.
    risk_gain = TPs.ComputeRiskGain()
    pdp = TPs.PDP(dt.iloc[:, :19].values - 1)
    return risk_gain, pdp

# Baseline experiment: 100 independent onerun(0.1) calls on 18 joblib workers.
results = list(Parallel(n_jobs=18)(delayed(onerun)(0.1) for i in range(100)))

runningreuslts = results

# Split each (risk/gain, PDP) pair; a PDP is reduced to its std along axis 1.
riskgainlist = np.array([risk_gain for risk_gain, _ in runningreuslts])
pdp_importance_list = np.array([pdp.std(axis=1) for _, pdp in runningreuslts])

# Rank the runs by column 1 of the risk/gain output (ascending).
# NOTE(review): [-10:][::-5] keeps only two of the ten highest-ranked runs
# (the best and the 6th best); confirm this is the intended selection.
_ranked_runs = riskgainlist[:, 1].argsort()
pdp_new = pdp_importance_list[_ranked_runs[-10:][::-5]].mean(axis=0)

# Reference rule built with CD_234_shared2 on the same inputs.
PL234_CD_shared2 = CD_234_shared2(
    [X4_input, X3a_input, X3b_input, X3c_input, X2a_input,
     X2b_input, X2c_input, X2d_input, X2e_input, X2f_input],
    Risk_table, Gain_table, 4000, 0.1, 0.01)
TP0 = PL234_CD_shared2
pdp_old = PL234_CD_shared2.PDP(dt.iloc[:, :19].values - 1).std(axis=1)

# Scale both importance vectors so that each sums to one.
pdp_old_scale = pdp_old / pdp_old.sum()
pdp_new_scale = pdp_new / pdp_new.sum()
# Figure titles, one per outcome.
mytitle1 = 'Regional safety as the outcome'
mytitle2 = 'Regional economy as the outcome'
mytitle3 = 'Regional civic society as the outcome'

# 30 values exp(-2.3 - 0.1*i), i = 0..29, flipped into increasing order.
Epsilon = np.flip(np.array([np.exp(-2.3 - i * 0.1) for i in range(30)]))

# One row of scaled PDP importances (19 scores) per epsilon value.
PDP_table = np.empty((30, 19))
for ite, epsilon in enumerate(Epsilon):
    # 10 independent runs per epsilon value.
    results = list(Parallel(n_jobs=18)(delayed(onerun)(epsilon) for i in range(10)))
    runningreuslts = results
    riskgainlist = np.array([risk_gain for risk_gain, _ in runningreuslts])
    pdp_importance_list = np.array([pdp.std(axis=1) for _, pdp in runningreuslts])
    # NOTE(review): with 10 runs, [-10:][::-5] averages the best and the 6th
    # best run by column 1 of the risk/gain output; confirm this is intended.
    _ranked_runs = riskgainlist[:, 1].argsort()
    pdp_new = pdp_importance_list[_ranked_runs[-10:][::-5]].mean(axis=0)
    # This rebinds the module-level pdp_new_scale on every iteration.
    pdp_new_scale = pdp_new / pdp_new.sum()
    PDP_table[ite, :] = pdp_new_scale

# Quick look: one line per score, importance against epsilon.
plt.plot(Epsilon, PDP_table)

plt.show()

# Bare expression: only displays something in an interactive session.
PDP_table[:, 0]

# Display names of the 19 scores, in the column order used by PDP_table.
Allnames = np.array([
    "Enemy Military Presence",
    "Enemy Military Activity",
    "Impact of Military Activity",
    "Friendly Military Presence",
    "Friendly Military Activity",
    "Law Enforcement",
    "PSDF Activity",
    "Enemy Political Presence",
    "Enemy Political Activity",
    "Administration",
    "RD Cadre",
    "Information PSYOPS",
    "Political Mobilization",
    "Public Health",
    "Education",
    "Social Welfare",
    "Development Assistance",
    "Economic Activity",
    "Land Tenure",
])
# Three groups of names taken as slices of Allnames. "PSDF Activity"
# (index 6) is deliberately kept in both of the first two groups, as before;
# the first matching group wins when colours are assigned below.
NameCategory = [
    Allnames[:7].tolist(),
    Allnames[6:13].tolist(),
    Allnames[13:].tolist(),
]

fontsize = 18
linewidth = 4
marksize = 50
colorset = ['red', 'blue', 'green']
colormeaning = ['submodels on security', 'submodels on politics', 'submodels on socioeconomics']
fig, ax = plt.subplots(figsize=(17, 8))

# One curve per score, coloured by group. Labels are attached to the lines,
# but no legend is drawn in this section.
for i, score_name in enumerate(Allnames):
    # Index of the first group containing the name; 2 if neither of the first two does.
    _group = next((g for g in (0, 1) if score_name in NameCategory[g]), 2)
    mycolor = colorset[_group]
    mymeaning = colormeaning[_group]
    ax.plot(Epsilon, PDP_table[:, i], color=mycolor, label=mymeaning, linewidth=linewidth)
fig.show()

def plot_relative_importance_numerical(mytitle):
    """Plot, per score, the shift from the old to the new scaled PDP importance.

    Reads the module-level ``pdp_old_scale`` and ``pdp_new_scale`` (whatever
    they hold at call time). Each of the 19 scores gets a horizontal segment
    from its old to its new value, with an arrow when the two differ.

    Args:
        mytitle: used only for the output file name; spaces are replaced by
            underscores and the figure is saved as
            ``Visualization/PDP/<title>_scaled.pdf``.
    """
    score_names = np.array([
        "Enemy Military Presence",
        "Enemy Military Activity",
        "Impact of Military Activity",
        "Friendly Military Presence",
        "Friendly Military Activity",
        "Law Enforcement",
        "PSDF Activity",
        "Enemy Political Presence",
        "Enemy Political Activity",
        "Administration",
        "RD Cadre",
        "Information PSYOPS",
        "Political Mobilization",
        "Public Health",
        "Education",
        "Social Welfare",
        "Development Assistance",
        "Economic Activity",
        "Land Tenure",
    ])
    # "PSDF Activity" (index 6) sits in both of the first two groups; the
    # first matching group wins below.
    name_groups = [
        score_names[:7].tolist(),
        score_names[6:13].tolist(),
        score_names[13:].tolist(),
    ]
    text_size = 18
    line_width = 4
    arrow_scale = 50
    palette = ['red', 'blue', 'green']
    legend_text = ['submodels on security', 'submodels on politics', 'submodels on socioeconomics']

    fig, ax = plt.subplots(figsize=(17, 8))

    # Dashed separators between the three groups of rows.
    for level in (6.5, 12.5):
        ax.plot([0.01, 0.15], [level, level], linewidth=line_width / 1.5, c='black', linestyle='--')

    # Captions for the three groups.
    captions = (
        (13, 'Socioeconomical submodel scores', 'g'),
        (7, 'Political submodel scores', 'b'),
        (5, 'Military submodel scores', 'r'),
    )
    for height, caption, colour in captions:
        ax.text(0.1, height, caption, c=colour, fontsize=text_size)

    plt.subplots_adjust(left=0.2, right=0.95, hspace=0.3)
    ax.set_ylim((-1, 19))
    ax.set_xlabel('Scaled PDP importance in the decision rule', fontsize=text_size)
    ax.set_yticks(np.arange(0, 19), score_names)
    ax.tick_params(axis='both', which='major', labelsize=text_size)

    for row in range(19):
        # Index of the first group containing the name; 2 if neither of the first two does.
        group = next((g for g in (0, 1) if score_names[row] in name_groups[g]), 2)
        row_colour = palette[group]
        old_value = pdp_old_scale[row]
        new_value = pdp_new_scale[row]
        ax.plot([old_value, new_value], [row, row], color=row_colour,
                label=legend_text[group], linewidth=line_width)
        shift = abs(new_value - old_value)
        if shift > 0.0:
            # Arrow from the old to the new value; its size grows with the shift.
            ax.annotate(
                '',
                xytext=(old_value, row),
                xy=(new_value, row),
                arrowprops=dict(lw=line_width / 2, arrowstyle="->", color=row_colour),
                size=arrow_scale * (shift * 100 + 1.5) / 3,
            )

    ax.legend(fontsize=text_size, loc='upper right')
    # Rebuild the legend with one entry per distinct label (first handle wins).
    handles, names = plt.gca().get_legend_handles_labels()
    first_handle = {}
    for handle, name in zip(handles, names):
        first_handle.setdefault(name, handle)
    plt.legend(list(first_handle.values()), list(first_handle.keys()),
               fontsize=text_size / 1.2, loc='upper right')
    fig.tight_layout()
    fig.show()
    fig.savefig('Visualization/PDP/' + mytitle.replace(" ", "_") + '_scaled.pdf')

# Grid of 30 values exp(-2.3 - 0.1*i), i = 0..29, flipped into increasing order.
Epsilon = np.flip(np.array([np.exp(-2.3 - i * 0.1) for i in range(30)]))

def simu_3(epsilon, R):
    """Run ``R`` optimisations at ``epsilon`` and scale old/new PDP importances.

    Fix: ``epsilon`` is now forwarded to ``onerun``. Previously every run used
    a hard-coded 0.1, so the parameter had no effect.

    Args:
        epsilon: value forwarded to ``onerun`` for every run.
        R: number of independent ``onerun`` calls (shadows the module-level
            ``R`` inside this function).

    NOTE(review): within the visible body the scaled importances are computed
    but neither returned nor stored anywhere; confirm how callers are meant to
    obtain them.
    """
    results = list(Parallel(n_jobs=18)(delayed(onerun)(epsilon) for i in range(R)))
    runningreuslts = results

    # Split each (risk/gain, PDP) pair; a PDP is reduced to its std along axis 1.
    riskgainlist = np.array([risk_gain for risk_gain, _ in runningreuslts])
    pdp_importance_list = np.array([pdp.std(axis=1) for _, pdp in runningreuslts])

    # Rank runs by column 1 of the risk/gain output and average a selection.
    # Alternatives tried earlier: the single arg-max run, or the mean over all runs.
    # NOTE(review): [-10:][::-5] keeps only two of the ten highest-ranked runs
    # (the best and the 6th best); confirm this is the intended selection.
    pdp_new = pdp_importance_list[riskgainlist[:, 1].argsort()[-10:][::-5]].mean(axis=0)

    # Reference rule. NOTE(review): it still uses the fixed values 0.1 and 0.01
    # rather than ``epsilon``; confirm whether it should follow the parameter.
    TP0 = PL234_CD_shared2 = CD_234_shared2(
            [X4_input, X3a_input, X3b_input, X3c_input, X2a_input, X2b_input, X2c_input, X2d_input, X2e_input, X2f_input],
            Risk_table, Gain_table, 4000, 0.1, 0.01)
    pdp_old = PL234_CD_shared2.PDP(dt.iloc[:, :19].values - 1).std(axis=1)

    # Scale both importance vectors so that each sums to one.
    pdp_old_scale = pdp_old/pdp_old.sum()
    pdp_new_scale = pdp_new/pdp_new.sum()

